"""
This file will compute the min, max, mean, and standard deviation of each datasets
in `pretrain_datasets.json` or `pretrain_datasets.json`.
"""

import json
import argparse
import os

# from multiprocessing import Pool, Manager

import tensorflow as tf
import numpy as np
from tqdm import tqdm

from data.vla_dataset import VLADataset
from data.hdf5_vla_dataset import HDF5VLADataset
from data.preprocess import generate_json_state


# Process each dataset to get the statistics
@tf.autograph.experimental.do_not_convert
def process_dataset(name_dataset_pair, max_episodes=100000):
    """Compute per-dimension state statistics for one named dataset.

    Args:
        name_dataset_pair: A ``(name, dataset)`` pair. ``dataset`` is iterated
            once; every episode it yields must provide the keys
            ``"episode_dict"`` and ``"dataset_name"``, which are passed to
            ``generate_json_state``. Element 1 of that function's result is
            taken as the state sequence and must expose ``.numpy()`` returning
            a 2-D array (reduced over axis 0, one statistic per column).
        max_episodes: Maximum number of episodes to accumulate. Lower it
            (e.g. to 10) for quick debugging runs.

    Returns:
        A dict with the keys ``"dataset_name"``, ``"state_mean"``,
        ``"state_std"``, ``"state_min"`` and ``"state_max"``; all statistics
        are plain Python lists with one entry per state dimension.

        ``state_std`` is ``sqrt(max(sum(z^2) / nz - (sum(z) / N)^2 * (N / nz), 0))``
        where ``z`` is the state with entries of magnitude ``<= 1e-8`` set to
        zero, ``N`` is the total number of rows, and ``nz`` is the per-column
        count of entries with magnitude ``> 1e-8`` (clamped to at least 1).

    Raises:
        ValueError: If the dataset yields no state rows at all, since no
            statistics can be computed in that case.
    """
    # print(f"PID {os.getpid()} processing {name_dataset_pair[0]}")
    dataset_iter = name_dataset_pair[1]

    EPS = 1e-8
    state_sum = 0.0
    state_sum_sq = 0.0
    z_state_sum = 0.0
    z_state_sum_sq = 0.0
    state_cnt = 0
    nz_state_cnt = None
    state_max = None
    state_min = None
    for episode_cnt, episode in enumerate(dataset_iter, start=1):
        if episode_cnt % 1000 == 0:
            print(f"Processing episodes {episode_cnt}/{max_episodes}")
        if episode_cnt > max_episodes:
            break

        res_tup = generate_json_state(episode["episode_dict"], episode["dataset_name"])

        # Convert to numpy
        states = res_tup[1].numpy()
        if states.shape[0] == 0:
            # An episode without any rows contributes nothing to the statistics
            # (and np.max / np.min would raise on the empty axis).
            continue

        # Accumulate in float64: the running sums span many rows, and the
        # variance formula subtracts two large, nearly equal quantities, so a
        # lower-precision source dtype would lose accuracy here.
        states_f64 = np.asarray(states, dtype=np.float64)
        abs_states = np.abs(states_f64)

        # Zero the values that are close to zero
        z_states = np.where(abs_states <= EPS, 0.0, states_f64)
        # Compute the non-zero count
        if nz_state_cnt is None:
            nz_state_cnt = np.zeros(states.shape[1])
        nz_state_cnt += np.sum(abs_states > EPS, axis=0)

        # Update statistics
        state_sum += np.sum(states_f64, axis=0)
        state_sum_sq += np.sum(states_f64**2, axis=0)
        z_state_sum += np.sum(z_states, axis=0)
        z_state_sum_sq += np.sum(z_states**2, axis=0)
        state_cnt += states.shape[0]

        # Min / max are taken on the untouched array so their values (and
        # dtype) are exactly those of the source data.
        episode_max = np.max(states, axis=0)
        episode_min = np.min(states, axis=0)
        if state_max is None:
            state_max = episode_max
            state_min = episode_min
        else:
            state_max = np.maximum(state_max, episode_max)
            state_min = np.minimum(state_min, episode_min)

    if state_cnt == 0:
        raise ValueError(f"Dataset {name_dataset_pair[0]!r} yielded no states; cannot compute statistics.")

    # Clamp to at least one to avoid division by zero
    nz_state_cnt = np.maximum(nz_state_cnt, 1.0)

    variance = (z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt)

    result = {
        "dataset_name": name_dataset_pair[0],
        "state_mean": (state_sum / state_cnt).tolist(),
        # Round-off can push the variance slightly below zero; clip before sqrt.
        "state_std": np.sqrt(np.maximum(variance, 0.0)).tolist(),
        "state_min": state_min.tolist(),
        "state_max": state_max.tolist(),
    }

    return result


def process_hdf5_dataset(vla_dataset):
    """Compute per-dimension state statistics for an HDF5-backed dataset.

    Args:
        vla_dataset: Dataset object supporting ``len()``,
            ``get_item(i, state_only=True)`` and ``get_dataset_name()``. Each
            item must provide a ``"state"`` entry holding a 2-D array
            (reduced over axis 0, one statistic per column).

    Returns:
        A dict with the keys ``"dataset_name"``, ``"state_mean"``,
        ``"state_std"``, ``"state_min"`` and ``"state_max"``; all statistics
        are plain Python lists with one entry per state dimension.

        ``state_std`` is ``sqrt(max(sum(z^2) / nz - (sum(z) / N)^2 * (N / nz), 0))``
        where ``z`` is the state with entries of magnitude ``<= 1e-8`` set to
        zero, ``N`` is the total number of rows, and ``nz`` is the per-column
        count of entries with magnitude ``> 1e-8`` (clamped to at least 1).

    Raises:
        ValueError: If the dataset provides no state rows at all, since no
            statistics can be computed in that case.
    """
    EPS = 1e-8
    state_sum = 0.0
    state_sum_sq = 0.0
    z_state_sum = 0.0
    z_state_sum_sq = 0.0
    state_cnt = 0
    nz_state_cnt = None
    state_max = None
    state_min = None
    for i in tqdm(range(len(vla_dataset))):
        episode = vla_dataset.get_item(i, state_only=True)

        states = episode["state"]
        if states.shape[0] == 0:
            # An episode without any rows contributes nothing to the statistics
            # (and np.max / np.min would raise on the empty axis).
            continue

        # Accumulate in float64: the running sums span many rows, and the
        # variance formula subtracts two large, nearly equal quantities, so a
        # lower-precision source dtype would lose accuracy here.
        states_f64 = np.asarray(states, dtype=np.float64)
        abs_states = np.abs(states_f64)

        # Zero the values that are close to zero
        z_states = np.where(abs_states <= EPS, 0.0, states_f64)
        # Compute the non-zero count
        if nz_state_cnt is None:
            nz_state_cnt = np.zeros(states.shape[1])
        nz_state_cnt += np.sum(abs_states > EPS, axis=0)

        # Update statistics
        state_sum += np.sum(states_f64, axis=0)
        state_sum_sq += np.sum(states_f64**2, axis=0)
        z_state_sum += np.sum(z_states, axis=0)
        z_state_sum_sq += np.sum(z_states**2, axis=0)
        state_cnt += states.shape[0]

        # Min / max are taken on the untouched array so their values (and
        # dtype) are exactly those of the source data.
        episode_max = np.max(states, axis=0)
        episode_min = np.min(states, axis=0)
        if state_max is None:
            state_max = episode_max
            state_min = episode_min
        else:
            state_max = np.maximum(state_max, episode_max)
            state_min = np.minimum(state_min, episode_min)

    if state_cnt == 0:
        raise ValueError(f"Dataset {vla_dataset.get_dataset_name()!r} yielded no states; cannot compute statistics.")

    # Clamp to at least one to avoid division by zero
    nz_state_cnt = np.maximum(nz_state_cnt, 1.0)

    variance = (z_state_sum_sq / nz_state_cnt) - (z_state_sum / state_cnt)**2 * (state_cnt / nz_state_cnt)

    result = {
        "dataset_name": vla_dataset.get_dataset_name(),
        "state_mean": (state_sum / state_cnt).tolist(),
        # Round-off can push the variance slightly below zero; clip before sqrt.
        "state_std": np.sqrt(np.maximum(variance, 0.0)).tolist(),
        "state_min": state_min.tolist(),
        "state_max": state_max.tolist(),
    }

    return result


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Multiprocessing currently with bugs
    # parser.add_argument('--n_workers', type=int, default=1,
    #                     help="Number of parallel workers.")
    parser.add_argument(
        "--dataset_type",
        type=str,
        default="pretrain",
        help="Whether to load the pretrain dataset or finetune dataset.",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default="configs/dataset_stat.json",
        help="JSON file path to save the dataset statistics.",
    )
    parser.add_argument(
        "--skip_exist",
        action="store_true",
        help="Whether to skip the existing dataset statistics.",
    )
    parser.add_argument(
        "--hdf5_dataset",
        action="store_true",
        help="Whether to load the dataset from the HDF5 files.",
    )
    args = parser.parse_args()

    if args.hdf5_dataset:
        vla_dataset = HDF5VLADataset()
        dataset_name = vla_dataset.get_dataset_name()

        try:
            with open(args.save_path, "r") as f:
                results = json.load(f)
        except FileNotFoundError:
            results = {}
        if args.skip_exist and dataset_name in results:
            print(f"Skipping existed {dataset_name} dataset statistics")
        else:
            print(f"Processing {dataset_name} dataset")
            result = process_hdf5_dataset(vla_dataset)
            results[result["dataset_name"]] = result
            with open(args.save_path, "w") as f:
                json.dump(results, f, indent=4)
        print("All datasets have been processed.")
        os._exit(0)

    vla_dataset = VLADataset(seed=0, dataset_type=args.dataset_type, repeat=False)
    name_dataset_pairs = vla_dataset.name2dataset.items()
    # num_workers = args.n_workers

    for name_dataset_pair in tqdm(name_dataset_pairs):
        try:
            with open(args.save_path, "r") as f:
                results = json.load(f)
        except FileNotFoundError:
            results = {}

        if args.skip_exist and name_dataset_pair[0] in results:
            print(f"Skipping existed {name_dataset_pair[0]} dataset statistics")
            continue
        print(f"Processing {name_dataset_pair[0]} dataset")

        result = process_dataset(name_dataset_pair)

        results[result["dataset_name"]] = result

        # Save the results in the json file after each dataset (for resume)
        with open(args.save_path, "w") as f:
            json.dump(results, f, indent=4)

    print("All datasets have been processed.")

    # with Manager() as manager:
    #     # Create shared dictionary and lock through the manager, accessible by all processes
    #     progress = manager.dict(processed=0, results={})
    #     progress_lock = manager.Lock()

    #     # Callback function to update progress
    #     def update_progress(result):
    #         with progress_lock:
    #             progress['processed'] += 1
    #             print(f"{result['dataset_name']} - {progress['processed']}/{len(name_dataset_pairs)} datasets have been processed")
    #             # Append the result to the shared dictionary
    #             progress['results'][result["dataset_name"]] = result

    #     with Pool(num_workers) as p:
    #         for name_dataset_pair in name_dataset_pairs:
    #             p.apply_async(process_dataset, args=(name_dataset_pair,), callback=update_progress)

    #         # Close the pool and wait for the work to finish
    #         p.close()
    #         p.join()

    # # Save the results in the json file
    # with open(args.save_path, 'w') as f:
    #     json.dump(progress['results'], f, indent=4)
